home *** CD-ROM | disk | FTP | other *** search
/ Language/OS - Multiplatform Resource Library / LANGUAGE OS.iso / cpp_libs / nihcl-30.lha / nihcl-3.0 / ex / Matrix.c < prev    next >
C/C++ Source or Header  |  1990-05-15  |  13KB  |  528 lines

  1. // Matrix.c -- matrix of type double
  2.  
  3. // $Header: /afs/alw.nih.gov/unix/sun4_40c/usr/local/src/nihcl-3.0/share/ex/RCS/Matrix.c,v 3.0 90/05/15 22:43:51 kgorlen Rel $
  4.  
  5. /*
  6. Author:
  7.     S. M. Orlow
  8.     Systex, Inc.
  9.     Beltsville, MD 20705
  10.     301-474-0111
  11.     sandy@alw.nih.gov
  12. */
  13.  
  14. #include <libc.h>
  15. #include <iostream.h>
  16.  
  17. #include "Matrix.h"
  18.  
  19. void Matrix::sizeError(char* where,const Matrix& m,int a, int b)
  20. {
  21.     cerr << where << ": " 
  22.          << m.nRow() << "x" << m.nCol() << " Matrix expected, "
  23.          << a << "x" << b << " found" << endl;
  24.     abort();
  25. }
  26. void Matrix::v_alloc(int a, int b)
  27. {
  28.     nrow = a; ncol = b;
  29. #ifdef TRACE
  30. cerr << "v_alloc: nrow=" << nrow << " ncol=" << ncol << endl;
  31. #endif
  32.     _v = new double[a*b];
  33. }
  34. Matrix::Matrix(int nr,int nc,double* f)
  35. {
  36.     v_alloc(nr,nc);
  37.  
  38.     int i,j;
  39.         if ( f==0 ) { // initilize 0 matrix
  40.       for (i=0; i<nrow; i++)
  41.         for (j=0; j<ncol; j++)
  42.           at(i,j) = 0;
  43.       }
  44.    else if ( nrow==1 ) { // initialize row
  45.       for (j=0; j<ncol; j++) at(0,j) = f[j];
  46.       }
  47.    else if ( ncol==1 ) { // initialize column
  48.       for (i=0; i<nrow; i++) at(i,0) = f[i];
  49.       }
  50.          else { // initialize matrix
  51.       for (i=0; i<nrow; i++)
  52.         for (j=0; j<ncol; j++)
  53.           at(i,j) = f[i*ncol+j];
  54.       }      
  55. }
  56. Matrix::Matrix(const Matrix& m)
  57. {
  58.     v_alloc(m.nrow,m.ncol);
  59.     *this = m;
  60. }
  61. Matrix::Matrix(const MatrixRow& v)
  62. {
  63.     v_alloc(1,v.nCol());
  64.     for (int j=0; j<ncol; j++) at(0,j) = v.at(j);
  65. }
  66. Matrix::Matrix(const MatrixCol& v)
  67. {
  68.     v_alloc(v.nRow(),1);
  69.     for (int i=0; i<nrow; i++) at(i,0) = v.at(i);
  70. }
  71. Matrix::Matrix(int k, double* f)
  72. {
  73.     v_alloc(k,k);
  74.     for (int i=0; i<nrow; i++)
  75.       for (int j=0; j<ncol; j++)
  76.         at(i,j) = (i==j)? f[i]:0;
  77. }
  78. Matrix::Matrix(int k, diagonal f)
  79. {
  80.     v_alloc(k,k);
  81.     for (int i=0; i<nrow; i++)
  82.       for (int j=0; j<ncol; j++)
  83.         at(i,j) = (i==j)? (double)f:0;
  84. }
  85. Matrix::~Matrix()
  86. {
  87.     delete _v;
  88. }
  89.  
  90. double& Matrix::operator()(int irow, int icol) const
  91. {
  92.     if ( irow<0||irow>=nrow||icol<0||icol>=ncol ) {
  93.       cerr << "at: [" 
  94.            << irow << "," << icol
  95.            << "] out of range [" 
  96.            << nrow << "," << ncol 
  97.            << "]" << endl;
  98.       abort();
  99.       }
  100.     return at(irow,icol);
  101. }
  102. MatrixRow Matrix::row(int k) const
  103. {
  104.     return MatrixRow(k,*this);
  105. }
  106. MatrixRow Matrix::row(int k,const MatrixRow& v) const
  107. {
  108.     for ( int i=0; i<ncol; i++ )
  109.       at(k,i) = v.at(i);
  110.     return v;
  111. }
  112. MatrixCol Matrix::col(int k) const
  113. {
  114.     return MatrixCol(k,*this);
  115. }
  116. MatrixCol Matrix::col(int k,const MatrixCol& v) const
  117. {
  118.     for ( int i=0; i<nrow; i++ )
  119.       at(i,k) = v.at(i);
  120.     return v;
  121. }
  122.  
  123. void Matrix::operator=(const Matrix& m)
  124. {
  125.     if ( ! sameSize(m.nrow,m.ncol) ) 
  126.           Matrix::sizeError("operator=",*this,m.nrow,m.ncol);
  127.     for (int i=0; i<nrow; i++)
  128.       for (int j=0; j<ncol; j++)
  129.         at(i,j) = m.at(i,j);
  130. }
  131. int Matrix::operator==(const Matrix& m) const
  132. {
  133.     if ( ! sameSize(m.nrow,m.ncol) ) return 0;
  134.     for (int i=0; i<nrow; i++)
  135.       for (int j=0; j<ncol; j++)
  136.         if ( at(i,j)!= m.at(i,j) ) return 0;
  137.     return 1;
  138. }
  139.  
  140. Matrix operator+(const Matrix& m,const Matrix& n)
  141. {
  142.     if ( ! m.sameSize(n.nRow(),n.nCol()) ) 
  143.        Matrix::sizeError("operator+",m,n.nRow(),n.nCol());
  144. // C++2.0 bug
  145. //    Matrix rm(m.nRow(),m.nCol());
  146.     int nr = m.nRow(), nc =m.nCol();
  147.     Matrix rm(nr,nc);
  148.     for (int i=0; i<m.nRow(); i++)
  149.       for (int j=0; j<m.nCol(); j++)
  150.         rm.at(i,j) = m.at(i,j)+n.at(i,j);
  151.     return rm;
  152. }
  153. Matrix operator-(const Matrix& m,const Matrix& n)
  154. {
  155.     if ( ! m.sameSize(n.nRow(),n.nCol()) ) 
  156.       Matrix::sizeError("operator-",m,n.nRow(),n.nCol());
  157. // C++2.0 bug
  158. //    Matrix rm(m.nRow(),m.nCol());
  159.     int nr = m.nRow(), nc =m.nCol();
  160.     Matrix rm(nr,nc);
  161.     for (int i=0; i<m.nRow(); i++)
  162.       for (int j=0; j<m.nCol(); j++)
  163.         rm.at(i,j) = m.at(i,j)-n.at(i,j);
  164.     return rm;
  165. }
  166. Matrix operator-(const Matrix& m)
  167. {
  168. // C++2.0 bug
  169. //    Matrix rm(m.nRow(),m.nCol());
  170.     int nr = m.nRow(), nc =m.nCol();
  171.     Matrix rm(nr,nc);
  172.     for (int i=0; i<m.nRow(); i++)
  173.       for (int j=0; j<m.nCol(); j++)
  174.         rm.at(i,j) = -m.at(i,j);
  175.     return rm;
  176. }
  177. double det(const Matrix& m)
  178. {
  179.     if (m.nRow()!=m.nCol()) {
  180.       cerr << "det: not a square matrix" << endl;;
  181.       abort();
  182.       }
  183.     if (m.nRow()==1) return m.at(0,0);
  184.     if (m.nRow()==2)
  185.       return m.at(0,0)*m.at(1,1)-m.at(0,1)*m.at(1,0);
  186.     if (m.nRow()==3) 
  187.       return (  m.at(0,0)*m.at(1,1)*m.at(2,2)
  188.                    -m.at(0,1)*m.at(1,2)*m.at(2,0)
  189.                    +m.at(0,2)*m.at(2,1)*m.at(0,1) );
  190.     if ( m.isUpperTriangle() ) {
  191.        double val = 1;
  192.        for (int i=0;i<m.nRow(); i++) val *=m.at(i,i);
  193.        return val;
  194.        }
  195.     double val = 0;
  196.     int sign = 1;
  197.     for(int j=0; j<m.nCol(); j++) {
  198.       val += sign*m.at(0,j)*det(m.coFactor(0,j));
  199.       sign *= -1;
  200.       }
  201.     return val;
  202. }
  203. double norm(const Matrix& m)
  204. {
  205.     double val =0;
  206.     for(int i=0; i<m.nRow(); i++)
  207.       for(int j=0; j<m.nCol(); j++)
  208.         if ( m.at(i,j)>val ) val = m.at(i,j);
  209.     return val;
  210. }
  211. Matrix operator*(const Matrix& m,const Matrix& n)
  212. {
  213.     if ( m.nCol()!=n.nRow() ) {
  214.          cerr << "operator*: " << m.nCol() <<
  215.            "x* Matrix expected " 
  216.            << n.nRow() << "x* found." << endl;
  217.          abort();
  218.        }
  219. // C++2.0 bug
  220. //    Matrix rm(m.nRow(),n.nCol());
  221.     int nr = m.nRow(), nc =n.nCol();
  222.     Matrix rm(nr,nc);
  223.     for (int i=0; i<rm.nRow(); i++)
  224.       for (int j=0; j<rm.nCol(); j++)
  225.         rm.at(i,j) = m.row(i)^n.col(j);
  226.     return rm;
  227. }
  228. Matrix operator&(const Matrix& m, const Matrix& n)
  229. {
  230.     if ( m.nRow()!=n.nRow() ) {
  231.        cerr << "operator&: " << m.nRow() << " rows expected, "
  232.         << n.nRow() << " found" << endl;
  233.        abort();
  234.        }
  235. // C++2.0 bug
  236. //    Matrix rm(m.nRow(),m.nCol()+n.nCol());
  237.     int nr = m.nRow(), nc =m.nCol()+n.nCol();
  238.     Matrix rm(nr,nc);
  239.     for (int i=0; i<m.nCol(); i++ )
  240.        rm.col(i,m.col(i));
  241.     for (int j=0; j<n.nCol(); j++ )
  242.        rm.col(m.nCol()+j,n.col(j));
  243.     return rm;      
  244. }
  245. Matrix Matrix::t() const
  246. {
  247.     Matrix rm(ncol,nrow);
  248.     for (int i=0; i<rm.nrow; i++)
  249.       for (int j=0; j<rm.ncol; j++)
  250.         rm.at(i,j) = at(j,i);
  251.     return rm;
  252. }
  253.  
  254. Matrix operator*(double f,const Matrix& m)
  255. {
  256. // C++2.0 bug
  257. //    Matrix rm(m.nRow(),m.nCol());
  258.     int nr = m.nRow(), nc =m.nCol();
  259.     Matrix rm(nr,nc);
  260.     for (int i=0; i<rm.nRow(); i++)
  261.       for (int j=0; j<rm.nCol(); j++)
  262.         rm.at(i,j) = f*m.at(i,j);
  263.     return rm;
  264. }
  265. void Matrix::operator*=(double f)
  266. {
  267.     for (int i=0; i<nrow; i++)
  268.       for (int j=0; j<ncol; j++)
  269.         at(i,j) *= f;
  270. }
  271. void Matrix::switchRows(int i,int j)
  272. {
  273.     Matrix tmp(row(i));
  274.     row(i,row(j));
  275.     row(j,tmp.row(0));    
  276. }
  277. void Matrix::combineRows(int i, double a, int j)
  278. {
  279.     for(int h=0; h<ncol; h++)
  280.       at(i,h) = at(i,h) + a*at(j,h);
  281. }
  282.  
  283. int Matrix::isUpperTriangle() const
  284. {
  285.     for(int j=0; j<ncol; j++)
  286.        for(int i=j+1; i<nrow; i++)
  287.          if ( at(i,j)!=0 ) return 0;
  288.     return 1;
  289. }
  290.  
  291. Matrix Matrix::upperTriangle()
  292. {
  293.     Matrix I(nrow,(diagonal)1);
  294.  
  295.     if ( this->isUpperTriangle() ) return I;
  296.  
  297.     for(int j=0; j<ncol; j++) {
  298.        int b_row = nrow-1;  // 1st non-zero entry from bottom
  299.        int t_row = j;       // 1st zero entry from the top
  300.  
  301.        // switch rows until all zeros are at
  302.        // the bottom of jTH column 
  303.        while ( b_row>=t_row ) { 
  304.          while (b_row>j&&at(b_row,j)==0) --b_row;
  305.          if ( b_row==j ) break; // bottom at diagonal
  306.          while (b_row>=t_row&&at(t_row,j)!=0) ++t_row;
  307.          if ( t_row==nRow() ) break; // top at last row
  308. #ifdef TRACE
  309. cerr << "switchRows(" << b_row << "," << t_row << ")" << endl;
  310. #endif
  311.          switchRows(b_row,t_row); 
  312.          I.switchRows(b_row,t_row);
  313.          }
  314. /*
  315.        // put maximum entry on the diagonal in jTH column
  316.        for(int h=0; h<j; h++)
  317.          if (at(h,j) > at(j,j)) {
  318.            switchRows(h,j); 
  319.            I.switchRows(h,j);
  320.            }        
  321. */
  322.        // now b_row is last non-zero entry from the top
  323.        // now t_row is first zero entry from the bottom
  324.        // combine until all entries below diagonal in jTH column =)
  325.        if ( b_row<=j ) continue;
  326.        for(int i=j+1; i<=b_row; i++) {
  327.           double f = -at(i,j)/at(j,j);
  328. #ifdef TRACE
  329. cerr << "combineRows(" << i << "," << f << "," << j << ")" << endl;
  330. #endif
  331.           combineRows(i,f,j);
  332.                I.combineRows(i,f,j);
  333.           }
  334.        }
  335.     return I;
  336. }
  337. Matrix Matrix::coFactor(int irow, int jcol) const
  338. {
  339.     if ( irow==1||jcol==1 ) {
  340.       cerr << "coFactor: can't coFactor row or column matrix" <<
  341.         endl;
  342.       abort();
  343.       }
  344.     Matrix val(nrow-1,ncol-1);
  345.     int getcol, getrow =0;
  346.     for(int i=0; i<val.nRow(); i++) {
  347.       if ( getrow==irow ) ++getrow;
  348.       if ( getrow==nrow ) break;
  349.       getcol = 0;
  350.       for(int j=0; j<val.nCol(); j++) {
  351.         if ( getcol==jcol ) ++getcol;
  352.         if ( getcol==ncol ) continue;
  353.         val.at(i,j) = at(getrow,getcol);
  354.         ++getcol;
  355.         }
  356.       ++getrow;
  357.       }
  358.     return val;
  359. }
  360. Matrix Matrix::coFactor(int irow, int jcol,Matrix& m) const
  361. {
  362.     if ( irow==1||jcol==1 ) {
  363.       cerr << "coFactor: can't coFactor row or column matrix" <<
  364.         endl;
  365.       abort();
  366.       }
  367.     if ( m.nRow()!=nrow-1||m.nCol()!=ncol-1 ) {
  368.       cerr << "coFactor: argument is wrong size" << endl;
  369.       abort();
  370.       } 
  371.     Matrix val = coFactor(irow,jcol);
  372.     int putcol, putrow =0;
  373.     for(int i=0; i<m.nRow(); i++) {
  374.       if ( putrow==irow ) ++putrow;
  375.       if ( putrow==nrow ) break;
  376.       putcol = 0;
  377.       for(int j=0; j<m.nCol(); j++) {
  378.         if ( putcol==jcol ) ++putcol;
  379.         if ( putcol==ncol ) continue;
  380.         at(putrow,putcol) = m.at(i,j);
  381.         ++putcol;
  382.         }
  383.       ++putrow;
  384.       }
  385.     return val;
  386. }
  387. Matrix Matrix::operator~() const
  388. // 1. triangulate: *this = ~P*T
  389. // 2. when T isUpperTriangle ~T isUpperTriangle
  390. // 3. split: T = T.row(0) + subT
  391. // 4. I = ~T*T
  392. // 5. ~T.at(0,0) = 1/T.at(0,0)
  393. // 6. sub~T = ~(subT)
  394. // 7. ~T.row(0) = [1/T.at(0,0)]&B
  395. //    where T.at(0,0)*B = [t21 ... t2n]*~subT
  396. // 8. ~*this = ~T*P
  397. {
  398.     if ( nrow!=ncol ) {
  399.        cerr << "operator~: can't invert a non-square matrix" <<
  400.          endl;
  401.        abort();
  402.        }
  403.     if ( nrow==1 ) {
  404.        Matrix T(1,1);
  405.        T.at(0,0) = 1/at(0,0);
  406.        return T;
  407.        }
  408.     Matrix T(*this);
  409. // 1. triangulate: *this = ~P*T
  410.     Matrix P(nrow,ncol);
  411.     P = T.upperTriangle();
  412.     if ( det(T)==0 ) {
  413.       cerr << "operator~: can't invert singular matrix" << endl;
  414.       abort();
  415.       }      
  416.  
  417. // 2. when T isUpperTriangle ~T isUpperTriangle
  418. // 3. split: T = T.row(0) + subT
  419.     Matrix& r = *new Matrix(1,ncol-1,&(T.at(0,1)));
  420.     Matrix& _subT = *new Matrix(~(T.coFactor(0,0)));
  421.     Matrix& B = *new Matrix(-(1/T.at(0,0))*r*_subT);
  422.     Matrix& val = *new Matrix(nrow,ncol);
  423.     val.at(0,0) = 1/T.at(0,0);
  424.     for(int i=1; i<ncol; i++) val.at(0,i)=B.at(0,i-1);
  425.     val.coFactor(0,0,_subT);
  426.     return val*P; // P is now the row-reduction transformation    
  427. }
  428. void Matrix::printOn(ostream& strm) const
  429. {
  430.     for(int i=0; i<nrow; i++ )
  431.        strm << row(i) << endl;
  432. }
  433. void Matrix::dumpOn(ostream& strm) const
  434. {
  435.     strm << "Matrix[" << nrow << " " << ncol << endl;
  436.     printOn(strm);
  437.     strm << "]" << endl;
  438. }
  439. MatrixRow::MatrixRow(const MatrixRow& r)
  440. {
  441.     pm = r.pm;
  442.     _row = r._row;
  443. }
  444. MatrixRow::MatrixRow(int k,const Matrix& m)
  445. {
  446.     pm = (Matrix*)&m;
  447.     _row = k;
  448. }
  449. double MatrixRow::operator^(const MatrixCol& c) const
  450. {
  451.     if ( nCol()!=c.nRow() ) {
  452.        cerr << "operator^: 1x" << nCol()
  453.         << " mismatched with " << c.nRow() 
  454.         << "x1." << endl;
  455.        abort();
  456.        }   
  457.     double val = 0;
  458.     for (int i=0; i< nCol(); i++ )
  459.       val += at(i)*c.at(i);
  460.     return val;    
  461. }
  462. void MatrixRow::operator=(const MatrixRow& r)
  463. {
  464.     if ( nCol()!=r.nCol() ) {
  465.       cerr << "operator=: MatrixRow of size " << nCol()
  466.            << " expected, size " << r.nCol() << " found." << endl,
  467.       abort();
  468.       }
  469.     pm = r.pm;
  470.     _row = r._row;
  471. }
  472. void MatrixRow::printOn(ostream& strm) const
  473. {
  474.     strm << "[ ";
  475.     for( int i=0; i<nCol(); i++ )
  476.       strm << at(i) << " ";
  477.     strm << "]";
  478. }
  479. MatrixCol::MatrixCol(const MatrixCol& c)
  480. {
  481.     pm = c.pm;
  482.     _col = c._col;
  483. }
  484. MatrixCol::MatrixCol(int k, const Matrix& m)
  485. {
  486.     pm = (Matrix*)&m;
  487.     _col = k;
  488. }
  489. void MatrixCol::operator=(const MatrixCol& c)
  490. {
  491.     if ( nRow()!=c.nRow() ) {
  492.       cerr << "operator=: MatrixCol of size " 
  493.            << nRow() << " expected, size " 
  494.            << c.nRow() << " found." << endl;
  495.       abort();
  496.       }
  497.     pm = c.pm;
  498.     _col = c._col;
  499. }
  500. void MatrixCol::operator+=(const MatrixCol& c)
  501. {
  502.     for(int i=0; i<nRow(); i++ )
  503.       at(i) += c.at(i);
  504. }
  505. void MatrixCol::printOn(ostream& strm) const
  506. {
  507.     strm << "t[ ";
  508.     for( int i=0; i<nRow(); i++ )
  509.       strm << at(i) << " ";
  510.     strm << "]";
  511. }
  512.  
  513. ostream& operator<<(ostream& strm,const Matrix& m)
  514. {
  515.     m.printOn(strm);
  516.     return strm;
  517. }
  518. ostream& operator<<(ostream& strm,const MatrixRow& m)
  519. {
  520.     m.printOn(strm);
  521.     return strm;
  522. }
  523. ostream& operator<<(ostream& strm,const MatrixCol& m)
  524. {
  525.     m.printOn(strm);
  526.     return strm;
  527. }
  528.